{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *\n",
    "from fastai.text.core import *\n",
    "from fastai.text.data import *\n",
    "from fastai.text.models.core import *\n",
    "from fastai.text.models.awdlstm import *\n",
    "from fastai.callback.rnn import *\n",
    "from fastai.callback.progress import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp text.learner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learner for the text application\n",
    "\n",
    "> All the functions necessary to build `Learner` suitable for transfer learning in NLP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The most important functions of this module are `language_model_learner` and `text_classifier_learner`. They will help you define a `Learner` using a pretrained model. See the [text tutorial](http://docs.fast.ai/tutorial.text) for examples of use."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading a pretrained model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In text, to load a pretrained model, we need to adapt the embeddings of the vocabulary used for the pre-training to the vocabulary of our current corpus."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def match_embeds(old_wgts, old_vocab, new_vocab):\n",
    "    \"Adapt the embedding, decoder weight and decoder bias in `old_wgts` from `old_vocab` to `new_vocab`.\"\n",
    "    emb = old_wgts['0.encoder.weight']\n",
    "    dec_bias = old_wgts.get('1.decoder.bias', None)\n",
    "    has_bias = dec_bias is not None\n",
    "    n_new = len(new_vocab)\n",
    "    # Tokens of `new_vocab` that are missing from `old_vocab` receive the mean row (and the mean bias)\n",
    "    emb_mean = emb.mean(0)\n",
    "    new_emb = emb.new_zeros((n_new, emb.size(1)))\n",
    "    if has_bias:\n",
    "        bias_mean = dec_bias.mean(0)\n",
    "        new_bias = dec_bias.new_zeros((n_new,))\n",
    "    # Use the vocab's own `o2i` mapping when it has one, otherwise build token -> index from the sequence\n",
    "    if hasattr(old_vocab, 'o2i'): tok2idx = old_vocab.o2i\n",
    "    else: tok2idx = {tok:pos for pos,tok in enumerate(old_vocab)}\n",
    "    for row,tok in enumerate(new_vocab):\n",
    "        src = tok2idx.get(tok, -1)\n",
    "        found = src >= 0\n",
    "        new_emb[row] = emb[src] if found else emb_mean\n",
    "        if has_bias: new_bias[row] = dec_bias[src] if found else bias_mean\n",
    "    # `old_wgts` is modified in place and also returned\n",
    "    old_wgts['0.encoder.weight'] = new_emb\n",
    "    if '0.encoder_dp.emb.weight' in old_wgts: old_wgts['0.encoder_dp.emb.weight'] = new_emb.clone()\n",
    "    old_wgts['1.decoder.weight'] = new_emb.clone()\n",
    "    if has_bias: old_wgts['1.decoder.bias'] = new_bias\n",
    "    return old_wgts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For words in `new_vocab` that don't have a corresponding match in `old_vocab`, we use the mean of all pretrained embeddings. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here the weight matrix has 5 rows while the old vocab only names 3 tokens\n",
    "wgts = {'0.encoder.weight': torch.randn(5,3)}\n",
    "# Pass a copy: `match_embeds` writes the new tensors into the dict it receives\n",
    "new_wgts = match_embeds(wgts.copy(), ['a', 'b', 'c'], ['a', 'c', 'd', 'b'])\n",
    "old,new = wgts['0.encoder.weight'],new_wgts['0.encoder.weight']\n",
    "test_eq(new[0], old[0])\n",
    "test_eq(new[1], old[2])\n",
    "# 'd' is not in the old vocab, so it gets the mean over all rows of the old weights\n",
    "test_eq(new[2], old.mean(0))\n",
    "test_eq(new[3], old[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#With bias\n",
    "wgts = {'0.encoder.weight': torch.randn(5,3), '1.decoder.bias': torch.randn(5)}\n",
    "# Pass a copy: `match_embeds` writes the new tensors into the dict it receives\n",
    "new_wgts = match_embeds(wgts.copy(), ['a', 'b', 'c'], ['a', 'c', 'd', 'b'])\n",
    "old_w,new_w = wgts['0.encoder.weight'],new_wgts['0.encoder.weight']\n",
    "old_b,new_b = wgts['1.decoder.bias'],  new_wgts['1.decoder.bias']\n",
    "# Weights: known tokens keep their row, the unknown token 'd' gets the mean row\n",
    "test_eq(new_w[0], old_w[0])\n",
    "test_eq(new_w[1], old_w[2])\n",
    "test_eq(new_w[2], old_w.mean(0))\n",
    "test_eq(new_w[3], old_w[1])\n",
    "# Bias: same remapping, with the mean bias for the unknown token\n",
    "test_eq(new_b[0], old_b[0])\n",
    "test_eq(new_b[1], old_b[2])\n",
    "test_eq(new_b[2], old_b.mean(0))\n",
    "test_eq(new_b[3], old_b[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _get_text_vocab(dls):\n",
    "    \"Return `dls.vocab`, or its first element when `dls.vocab` is an `L`\"\n",
    "    vocab = dls.vocab\n",
    "    return vocab[0] if isinstance(vocab, L) else vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def load_ignore_keys(model, wgts):\n",
    "    \"Load `wgts` in `model` by position: the i-th tensor of `wgts` goes to the i-th entry of the model state dict\"\n",
    "    state = model.state_dict()\n",
    "    # Key names are never compared; `zip` stops at the shorter of the two collections\n",
    "    for target,source in zip(state.values(), wgts.values()):\n",
    "        target.data = source.data.clone()\n",
    "    return model.load_state_dict(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _rm_module(n):\n",
    "    \"Remove the last `module` component from the dotted name `n` (returned unchanged if there is none)\"\n",
    "    parts = n.split('.')\n",
    "    if 'module' in parts:\n",
    "        # Search the reversed list to find the position of the last occurrence\n",
    "        last = len(parts) - 1 - parts[::-1].index('module')\n",
    "        del parts[last]\n",
    "    return '.'.join(parts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "#For previous versions compatibility, remove for release\n",
    "def clean_raw_keys(wgts):\n",
    "    \"Delete from `wgts` (in place) every key `k` for which `f'{_rm_module(k)}_raw'` is also a key, then return `wgts`\"\n",
    "    # Snapshot the keys: entries are deleted while looping, and membership is checked against the original keys\n",
    "    keys = list(wgts.keys())\n",
    "    key_set = set(keys)\n",
    "    for k in keys:\n",
    "        if f'{_rm_module(k)}_raw' in key_set: del wgts[k]\n",
    "    return wgts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "#For previous versions compatibility, remove for release\n",
    "def load_model_text(file, model, opt, with_opt=None, device=None, strict=True):\n",
    "    \"Load `model` from `file` along with `opt` (if available, and if `with_opt`)\"\n",
    "    distrib_barrier()\n",
    "    # An int `device` is turned into the CUDA device with that index; `None` means load on CPU\n",
    "    if isinstance(device, int): device = torch.device('cuda', device)\n",
    "    elif device is None: device = 'cpu'\n",
    "    state = torch.load(file, map_location=device)\n",
    "    # The file is treated as model+optimizer only when its keys are exactly 'model' and 'opt'\n",
    "    hasopt = set(state)=={'model', 'opt'}\n",
    "    model_state = state['model'] if hasopt else state\n",
    "    get_model(model).load_state_dict(clean_raw_keys(model_state), strict=strict)\n",
    "    # `with_opt=None` means: load the optimizer state if possible, but stay silent on failure\n",
    "    if hasopt and ifnone(with_opt,True):\n",
    "        try: opt.load_state_dict(state['opt'])\n",
    "        except Exception:\n",
    "            # Best effort on purpose; `Exception` rather than a bare `except` so that\n",
    "            # KeyboardInterrupt and SystemExit are not swallowed here\n",
    "            if with_opt: warn(\"Could not load the optimizer state.\")\n",
    "    elif with_opt: warn(\"Saved filed doesn't contain an optimizer state.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(Learner.__init__)\n",
    "class TextLearner(Learner):\n",
    "    \"Basic class for a `Learner` in NLP.\"\n",
    "    def __init__(self, dls, model, alpha=2., beta=1., moms=(0.8,0.7,0.8), **kwargs):\n",
    "        super().__init__(dls, model, moms=moms, **kwargs)\n",
    "        # Every text learner gets a `ModelResetter` and an `RNNRegularizer(alpha, beta)` callback\n",
    "        self.add_cbs([ModelResetter(), RNNRegularizer(alpha=alpha, beta=beta)])\n",
    "\n",
    "    def save_encoder(self, file):\n",
    "        \"Save the encoder to `file` in the model directory\"\n",
    "        if rank_distrib(): return # don't save if child proc\n",
    "        # The encoder is the first element of the model\n",
    "        encoder = get_model(self.model)[0]\n",
    "        # If the encoder exposes a `module` attribute, save that inner module's state instead\n",
    "        if hasattr(encoder, 'module'): encoder = encoder.module\n",
    "        torch.save(encoder.state_dict(), join_path_file(file, self.path/self.model_dir, ext='.pth'))\n",
    "\n",
    "    def load_encoder(self, file, device=None):\n",
    "        \"Load the encoder `file` from the model directory, optionally ensuring it's on `device`\"\n",
    "        encoder = get_model(self.model)[0]\n",
    "        if device is None: device = self.dls.device\n",
    "        if hasattr(encoder, 'module'): encoder = encoder.module\n",
    "        distrib_barrier()\n",
    "        wgts = torch.load(join_path_file(file,self.path/self.model_dir, ext='.pth'), map_location=device)\n",
    "        encoder.load_state_dict(clean_raw_keys(wgts))\n",
    "        # The learner is frozen after loading, and returned to allow chaining\n",
    "        self.freeze()\n",
    "        return self\n",
    "\n",
    "    def load_pretrained(self, wgts_fname, vocab_fname, model=None):\n",
    "        \"Load a pretrained model and adapt it to the data vocabulary.\"\n",
    "        # NOTE(review): `load_pickle` and `torch.load` unpickle the files; only use trusted files\n",
    "        old_vocab = load_pickle(vocab_fname)\n",
    "        new_vocab = _get_text_vocab(self.dls)\n",
    "        distrib_barrier()\n",
    "        # This `map_location` keeps every storage where it was deserialized instead of moving it\n",
    "        wgts = torch.load(wgts_fname, map_location = lambda storage,loc: storage)\n",
    "        if 'model' in wgts: wgts = wgts['model'] #Just in case the pretrained model was saved with an optimizer\n",
    "        wgts = match_embeds(wgts, old_vocab, new_vocab)\n",
    "        # Weights are loaded by position (key names ignored) into `model`, or into `self.model` when `model` is None\n",
    "        load_ignore_keys(self.model if model is None else model, clean_raw_keys(wgts))\n",
    "        self.freeze()\n",
    "        return self\n",
    "\n",
    "    #For previous versions compatibility. Remove at release\n",
    "    @delegates(load_model_text)\n",
    "    def load(self, file, with_opt=None, device=None, **kwargs):\n",
    "        \"Load the model (and the optimizer state, depending on `with_opt`) from `file` in the model directory\"\n",
    "        if device is None: device = self.dls.device\n",
    "        if self.opt is None: self.create_opt()\n",
    "        file = join_path_file(file, self.path/self.model_dir, ext='.pth')\n",
    "        # `with_opt` is a named parameter, so it is not in `kwargs` and has to be forwarded explicitly\n",
    "        load_model_text(file, self.model, self.opt, with_opt=with_opt, device=device, **kwargs)\n",
    "        return self"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adds a `ModelResetter` and an `RNNRegularizer` with `alpha` and `beta` to the callbacks, the rest is the same as `Learner` init. \n",
    "\n",
    "This `Learner` adds functionality to the base class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TextLearner.load_pretrained\" class=\"doc_header\"><code>TextLearner.load_pretrained</code><a href=\"__main__.py#L28\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TextLearner.load_pretrained</code>(**`wgts_fname`**, **`vocab_fname`**, **`model`**=*`None`*)\n",
       "\n",
       "Load a pretrained model and adapt it to the data vocabulary."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TextLearner.load_pretrained)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`wgts_fname` should point to the weights of the pretrained model and `vocab_fname` to the vocabulary used to pretrain it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TextLearner.save_encoder\" class=\"doc_header\"><code>TextLearner.save_encoder</code><a href=\"__main__.py#L10\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TextLearner.save_encoder</code>(**`file`**)\n",
       "\n",
       "Save the encoder to `file` in the model directory"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TextLearner.save_encoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model directory is `Learner.path/Learner.model_dir`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"TextLearner.load_encoder\" class=\"doc_header\"><code>TextLearner.load_encoder</code><a href=\"__main__.py#L17\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>TextLearner.load_encoder</code>(**`file`**, **`device`**=*`None`*)\n",
       "\n",
       "Load the encoder `file` from the model directory, optionally ensuring it's on `device`"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TextLearner.load_encoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language modeling predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For language modeling, the predict method is quite different from the other applications, which is why it needs its own subclass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def decode_spec_tokens(tokens):\n",
    "    \"Decode the special tokens in `tokens`\"\n",
    "    # `rule` is the special token currently being applied, `arg` the repeat count read for TK_REP/TK_WREP\n",
    "    new_toks,rule,arg = [],None,None\n",
    "    for t in tokens:\n",
    "        # A special token starts a new rule; forget any count left over from an earlier one\n",
    "        if t in [TK_MAJ, TK_UP, TK_REP, TK_WREP]: rule,arg = t,None\n",
    "        elif rule is None: new_toks.append(t)\n",
    "        elif rule == TK_MAJ:\n",
    "            new_toks.append(t[:1].upper() + t[1:].lower())\n",
    "            rule = None\n",
    "        elif rule == TK_UP:\n",
    "            new_toks.append(t.upper())\n",
    "            rule = None\n",
    "        elif arg is None:\n",
    "            # First token after TK_REP/TK_WREP is the count; if it is not an integer the rule is\n",
    "            # abandoned and this token is discarded\n",
    "            try:    arg = int(t)\n",
    "            except (ValueError, TypeError): rule = None\n",
    "        else:\n",
    "            # TK_REP repeats the characters inside one token, TK_WREP repeats the whole token\n",
    "            if rule == TK_REP: new_toks.append(t * arg)\n",
    "            else:              new_toks += [t] * arg\n",
    "            # The rule applies to a single token: reset so that the following tokens are not repeated too\n",
    "            rule,arg = None,None\n",
    "    return new_toks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'xxmaj' capitalizes the next token, 'xxup' upper-cases it\n",
    "test_eq(decode_spec_tokens(['xxmaj', 'text']), ['Text'])\n",
    "test_eq(decode_spec_tokens(['xxup', 'text']), ['TEXT'])\n",
    "# 'xxrep n c' gives one token with `c` repeated n times, 'xxwrep n w' gives the token `w` n times\n",
    "test_eq(decode_spec_tokens(['xxrep', '3', 'a']), ['aaa'])\n",
    "test_eq(decode_spec_tokens(['xxwrep', '3', 'word']), ['word', 'word', 'word'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class LMLearner(TextLearner):\n",
    "    \"Add functionality to `TextLearner` when dealing with a language model\"\n",
    "    def predict(self, text, n_words=1, no_unk=True, temperature=1., min_p=None, no_bar=False,\n",
    "                decoder=decode_spec_tokens, only_last_word=False):\n",
    "        \"Return `text` and the `n_words` that come after\"\n",
    "        self.model.reset()\n",
    "        # `idxs` is what gets fed to the model, `idxs_all` accumulates the whole generated sequence\n",
    "        idxs = idxs_all = self.dls.test_dl([text]).items[0].to(self.dls.device)\n",
    "        if no_unk: unk_idx = self.dls.vocab.index(UNK)\n",
    "        for _ in (range(n_words) if no_bar else progress_bar(range(n_words), leave=False)):\n",
    "            with self.no_bar(): preds,_ = self.get_preds(dl=[(idxs[None],)])\n",
    "            # Keep the last entry of the first (only) prediction\n",
    "            res = preds[0][-1]\n",
    "            if no_unk: res[unk_idx] = 0.\n",
    "            if min_p is not None:\n",
    "                if (res >= min_p).float().sum() == 0:\n",
    "                    warn(f\"There is no item with probability >= {min_p}, try a lower value.\")\n",
    "                else: res[res < min_p] = 0.\n",
    "            if temperature != 1.: res.pow_(1 / temperature)\n",
    "            # Sample the next index with `res` as (unnormalized) weights\n",
    "            idx = torch.multinomial(res, 1).item()\n",
    "            idxs = idxs_all = torch.cat([idxs_all, idxs.new([idx])])\n",
    "            # With `only_last_word`, only the newly sampled index is fed at the next step\n",
    "            if only_last_word: idxs = idxs[-1][None]\n",
    "\n",
    "        num = self.dls.train_ds.numericalize\n",
    "        # BOS and PAD tokens are left out of the decoded text\n",
    "        tokens = [num.vocab[i] for i in idxs_all if num.vocab[i] not in [BOS, PAD]]\n",
    "        sep = self.dls.train_ds.tokenizer.sep\n",
    "        return sep.join(decoder(tokens))\n",
    "\n",
    "    @delegates(Learner.get_preds)\n",
    "    def get_preds(self, concat_dim=1, **kwargs):\n",
    "        \"Call the parent `get_preds`, with `concat_dim` defaulting to 1\"\n",
    "        # Forward the caller's `concat_dim` rather than a hard-coded 1\n",
    "        return super().get_preds(concat_dim=concat_dim, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h3 id=\"LMLearner\" class=\"doc_header\"><code>class</code> <code>LMLearner</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h3>\n",
       "\n",
       "> <code>LMLearner</code>(**`dls`**, **`model`**, **`alpha`**=*`2.0`*, **`beta`**=*`1.0`*, **`moms`**=*`(0.8, 0.7, 0.8)`*, **`loss_func`**=*`None`*, **`opt_func`**=*`Adam`*, **`lr`**=*`0.001`*, **`splitter`**=*`trainable_params`*, **`cbs`**=*`None`*, **`metrics`**=*`None`*, **`path`**=*`None`*, **`model_dir`**=*`'models'`*, **`wd`**=*`None`*, **`wd_bn_bias`**=*`False`*, **`train_bn`**=*`True`*) :: [`TextLearner`](/text.learner.html#TextLearner)\n",
       "\n",
       "Add functionality to [`TextLearner`](/text.learner.html#TextLearner) when dealing with a language model"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(LMLearner, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"LMLearner.predict\" class=\"doc_header\"><code>LMLearner.predict</code><a href=\"__main__.py#L5\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>LMLearner.predict</code>(**`text`**, **`n_words`**=*`1`*, **`no_unk`**=*`True`*, **`temperature`**=*`1.0`*, **`min_p`**=*`None`*, **`no_bar`**=*`False`*, **`decoder`**=*`decode_spec_tokens`*, **`only_last_word`**=*`False`*)\n",
       "\n",
       "Return `text` and the `n_words` that come after"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(LMLearner.predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The words are picked randomly among the predictions, depending on the probability of each index. `no_unk` means we never pick the `UNK` token, and `temperature` is applied to the predictions. If `min_p` is passed, we don't consider the indices with a probability lower than it. Set `no_bar` to `True` if you don't want any progress bar, and you can pass along a custom `decoder` to process the predicted tokens."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Learner` convenience functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.text.models.core import _model_meta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _get_text_vocab(dls):\n",
    "    # NOTE(review): identical to the `_get_text_vocab` defined in an earlier cell of this notebook; one copy is redundant\n",
    "    vocab = dls.vocab\n",
    "    # When `dls.vocab` is an `L`, its first element is used as the text vocab\n",
    "    if isinstance(vocab, L): vocab = vocab[0]\n",
    "    return vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(Learner.__init__)\n",
    "def language_model_learner(dls, arch, config=None, drop_mult=1., backwards=False, pretrained=True, pretrained_fnames=None, **kwargs):\n",
    "    \"Create a `Learner` with a language model from `dls` and `arch`.\"\n",
    "    n_tokens = len(_get_text_vocab(dls))\n",
    "    model = get_language_model(arch, n_tokens, config=config, drop_mult=drop_mult)\n",
    "    meta = _model_meta[arch]\n",
    "    learn = LMLearner(dls, model, loss_func=CrossEntropyLossFlat(), splitter=meta['split_lm'], **kwargs)\n",
    "    # Nothing to load: return the freshly created learner\n",
    "    if not (pretrained or pretrained_fnames): return learn\n",
    "    if pretrained_fnames is not None:\n",
    "        # User-supplied names are looked up in the model directory: first one as '.pth', second one as '.pkl'\n",
    "        fnames = [learn.path/learn.model_dir/f'{name}.{ext}' for name,ext in zip(pretrained_fnames, ['pth', 'pkl'])]\n",
    "    else:\n",
    "        url_key = 'url_bwd' if backwards else 'url'\n",
    "        if url_key not in meta:\n",
    "            warn(\"There are no pretrained weights for that architecture yet!\")\n",
    "            return learn\n",
    "        model_path = untar_data(meta[url_key] , c_key='model')\n",
    "        # Take the first '.pth' and the first '.pkl' file found in the downloaded folder\n",
    "        fnames = []\n",
    "        try:\n",
    "            for ext in ['pth', 'pkl']: fnames.append(list(model_path.glob(f'*.{ext}'))[0])\n",
    "        except IndexError:\n",
    "            print(f'The model in {model_path} is incomplete, download again')\n",
    "            raise\n",
    "    return learn.load_pretrained(*fnames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use the `config` to customize the architecture used (change the values from `awd_lstm_lm_config` for this), `pretrained` will use fastai's pretrained model for this `arch` (if available) or you can pass specific `pretrained_fnames` containing your own pretrained model and the corresponding vocabulary. All other arguments are passed to `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jhoward/anaconda3/lib/python3.7/site-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
      "  return array(a, dtype, copy=False, order=order)\n"
     ]
    }
   ],
   "source": [
    "# Build language-model `DataLoaders` (`is_lm=True`) from the IMDB sample\n",
    "path = untar_data(URLs.IMDB_SAMPLE)\n",
    "df = pd.read_csv(path/'texts.csv')\n",
    "dls = TextDataLoaders.from_df(df, path=path, text_col='text', is_lm=True, valid_col='is_valid')\n",
    "# `pretrained` defaults to True, so pretrained weights are loaded when the architecture has some\n",
    "learn = language_model_learner(dls, AWD_LSTM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can then use the `.predict` method to generate new text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'This movie is about the first generation of human USA inventor Deals J. Simpson who is at the realized science'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.predict('This movie is about', n_words=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default the entire sentence is fed again to the model after each predicted word; this little trick improves the quality of the generated text. If you want to feed only the last word, pass `only_last_word=True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'This movie is about three - go in The Creating and rest of the next . Carol . Grotesque and'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.predict('This movie is about', n_words=20, only_last_word=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(Learner.__init__)\n",
    "def text_classifier_learner(dls, arch, seq_len=72, config=None, backwards=False, pretrained=True, drop_mult=0.5, n_out=None,\n",
    "                            lin_ftrs=None, ps=None, max_len=72*20, y_range=None, **kwargs):\n",
    "    \"Create a `Learner` with a text classifier from `dls` and `arch`.\"\n",
    "    n_tokens = len(_get_text_vocab(dls))\n",
    "    # Fall back on the number of outputs reported by the data when `n_out` is not given\n",
    "    if n_out is None: n_out = get_c(dls)\n",
    "    assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
    "    model = get_text_classifier(arch, n_tokens, n_out, seq_len=seq_len, config=config, y_range=y_range,\n",
    "                                drop_mult=drop_mult, lin_ftrs=lin_ftrs, ps=ps, max_len=max_len)\n",
    "    meta = _model_meta[arch]\n",
    "    learn = TextLearner(dls, model, splitter=meta['split_clas'], **kwargs)\n",
    "    if not pretrained: return learn\n",
    "    url_key = 'url_bwd' if backwards else 'url'\n",
    "    if url_key not in meta:\n",
    "        warn(\"There are no pretrained weights for that architecture yet!\")\n",
    "        return learn\n",
    "    model_path = untar_data(meta[url_key], c_key='model')\n",
    "    # Take the first '.pth' and the first '.pkl' file found in the downloaded folder\n",
    "    fnames = []\n",
    "    try:\n",
    "        for ext in ['pth', 'pkl']: fnames.append(list(model_path.glob(f'*.{ext}'))[0])\n",
    "    except IndexError:\n",
    "        print(f'The model in {model_path} is incomplete, download again')\n",
    "        raise\n",
    "    # The pretrained weights are loaded into the first element of the model only\n",
    "    learn = learn.load_pretrained(*fnames, model=learn.model[0])\n",
    "    learn.freeze()\n",
    "    return learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use the `config` to customize the architecture used (change the values from `awd_lstm_clas_config` for this), `pretrained` will use fastai's pretrained model for this `arch` (if available). `drop_mult` is a global multiplier applied to control all dropouts. `n_out` is usually inferred from the `dls` but you may pass it.\n",
    "\n",
    "The model uses a `SentenceEncoder`, which means the texts are passed `seq_len` tokens at a time, and will only compute the gradients on the last `max_len` steps. `lin_ftrs` and `ps` are passed to `get_text_classifier`.\n",
    "\n",
    "All other arguments are passed to `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jhoward/anaconda3/lib/python3.7/site-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
      "  return array(a, dtype, copy=False, order=order)\n"
     ]
    }
   ],
   "source": [
    "# Build classification `DataLoaders` (labels from the 'label' column) from the IMDB sample\n",
    "path = untar_data(URLs.IMDB_SAMPLE)\n",
    "df = pd.read_csv(path/'texts.csv')\n",
    "dls = TextDataLoaders.from_df(df, path=path, text_col='text', label_col='label', valid_col='is_valid')\n",
    "# `pretrained` defaults to True, so pretrained weights are loaded when the architecture has some\n",
    "learn = text_classifier_learner(dls, AWD_LSTM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show methods -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_results(x: LMTensorText, y, samples, outs, ctxs=None, max_n=10, **kwargs):\n",
    "    \"Show at most `max_n` rows with the input, target and prediction of a language model\"\n",
    "    if ctxs is None:\n",
    "        n_rows = min(len(samples), max_n)\n",
    "        ctxs = get_empty_df(n_rows)\n",
    "    # Elements 0 and 1 of each sample fill the 'input' and 'target' columns\n",
    "    for col,title in enumerate(['input', 'target']):\n",
    "        column = samples.itemgot(col)\n",
    "        ctxs = [item.show(ctx=ctx, label=title, **kwargs) for item,ctx,_ in zip(column,ctxs,range(max_n))]\n",
    "    # Element 0 of each output fills the 'pred' column\n",
    "    preds = outs.itemgot(0)\n",
    "    ctxs = [pred.show(ctx=ctx, label='pred', **kwargs) for pred,ctx,_ in zip(preds,ctxs,range(max_n))]\n",
    "    display_df(pd.DataFrame(ctxs))\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_results(x: TensorText, y, samples, outs, ctxs=None, max_n=10, trunc_at=150, **kwargs):\n",
    "    \"Show results for text inputs, with the text of each sample truncated at `trunc_at`\"\n",
    "    if ctxs is None:\n",
    "        n_rows = min(len(samples), max_n)\n",
    "        ctxs = get_empty_df(n_rows)\n",
    "    # Only the first element of each sample is truncated, the others are kept as they are\n",
    "    truncated = L((sample[0].truncate(trunc_at),*sample[1:]) for sample in samples)\n",
    "    # Delegate the filling of the rows to the `object` implementation of `show_results`\n",
    "    generic_show = show_results[object]\n",
    "    ctxs = generic_show(x, y, truncated, outs, ctxs=ctxs, max_n=max_n, **kwargs)\n",
    "    display_df(pd.DataFrame(ctxs))\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def plot_top_losses(x: TensorText, y:TensorCategory, samples, outs, raws, losses, trunc_at=150, **kwargs):\n",
    "    \"Display a table with input, target, predicted, probability and loss for each sample\"\n",
    "    def _add_columns(items, labels, rows):\n",
    "        # The i-th element of every item is shown under the i-th label\n",
    "        for pos,label in enumerate(labels):\n",
    "            rows = [cell.show(ctx=row, label=label, **kwargs) for cell,row in zip(items.itemgot(pos),rows)]\n",
    "        return rows\n",
    "\n",
    "    rows = get_empty_df(len(samples))\n",
    "    # Only the first element of each sample is truncated, at `trunc_at`\n",
    "    truncated = L((sample[0].truncate(trunc_at),*sample[1:]) for sample in samples)\n",
    "    rows = _add_columns(truncated, ['input', 'target'], rows)\n",
    "    # Append the max of the raw output and the loss to each output\n",
    "    scored = L(out + (TitledFloat(raw.max().item()), TitledFloat(loss.item())) for out,raw,loss in zip(outs, raws, losses))\n",
    "    rows = _add_columns(scored, ['predicted', 'probability', 'loss'], rows)\n",
    "    display_df(pd.DataFrame(rows))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
